from bluepyopt.ephys.responses import TimeVoltageResponse
import numpy as np
from bluepyparallel import init_parallel_factory
import yaml
from functools import partial

from pathlib import Path
import matplotlib.pyplot as plt
import pickle

# from emodel_generalisation.utils import plot_traces
from emodel_generalisation.model.modifiers import synth_soma, synth_axon
from emodel_generalisation.mcmc import load_chains
from emodel_generalisation.mcmc import save_selected_emodels
from emodel_generalisation.utils import get_feature_df
import logging
from itertools import cycle
import json
from matplotlib.backends.backend_pdf import PdfPages
import pandas as pd
from emodel_generalisation.model.access_point import AccessPoint
from emodel_generalisation.model.evaluation import feature_evaluation
from emodel_generalisation.utils import get_combo_hash
from datareuse import Reuse
import extra_features


def plot_traces(
    trace_df,
    trace_path="traces",
    pdf_filename="traces.pdf",
    xmax=14000,
    xmin=5500,
    protocol="Step_ReboundBurst_burst.soma.v",
):
    """Plot the voltage response of one protocol for every row of ``trace_df`` into a pdf.

    All rows are overlaid in black on a single figure named after ``protocol``. A
    ``trace_highlight`` column, if present, is not used for styling.

    Args:
        trace_df (DataFrame): contains list of combos with traces to plot. If it has a
            ``trace_data`` column, that column gives the path of each row's pickled trace;
            otherwise the path is built from ``trace_path`` and the row's combo hash.
        trace_path (str): path to folder with traces in .pkl (only used when
            ``trace_df`` has no ``trace_data`` column)
        pdf_filename (str): name of pdf to save
        xmax (float): upper limit of the time axis (ms)
        xmin (float): lower limit of the time axis (ms)
        protocol (str): key of the response to plot in each pickled trace
    """
    # Keep the folder fixed: rebuilding the path from a reassigned ``trace_path`` would
    # nest one row's pkl path inside the next one.
    trace_dir = Path(trace_path)
    fig = None
    for index in trace_df.index:
        if "trace_data" in trace_df.columns:
            trace_file = trace_df.loc[index, "trace_data"]
        else:
            combo_hash = get_combo_hash(trace_df.loc[index])
            trace_file = trace_dir / ("trace_id_" + str(combo_hash) + ".pkl")

        # NOTE: pickle can run arbitrary code; only load trace files produced locally.
        with open(trace_file, "rb") as f:
            trace = pickle.load(f)
        if isinstance(trace, list):
            trace = trace[1]  # newer version the response are here

        if protocol not in trace:
            continue
        response = trace[protocol]

        # Only touch pyplot state for the protocol we actually plot, so no stray
        # (empty) figure is created for other protocols in the pickle.
        if fig is None:
            fig = plt.figure(protocol, figsize=(5, 3))
        ax = fig.gca()
        ax.set_xlim(xmin, xmax)
        ax.plot(response["time"], response["voltage"], c="k", lw=0.8)
        ax.set_xlabel("Time (ms)")
        ax.set_ylabel("Voltage (mV)")

    with PdfPages(pdf_filename) as pdf:
        # Save and close only the figure drawn here, not other open pyplot figures.
        if fig is not None:
            fig.suptitle(fig.get_label())
            fig.tight_layout()
            pdf.savefig(fig)
            plt.close(fig)


if __name__ == "__main__":
    # Scan gkbar_iahp.basal for the "simplest_spp" emodel on the exemplar morphology,
    # evaluate features, then plot rebound-burst traces and feature-vs-conductance curves.
    parallel_factory = init_parallel_factory("multiprocessing")
    # parallel_factory = init_parallel_factory("dask_dataframe")

    # Verbosity index into (WARNING, INFO, DEBUG).
    verbosity = 1
    log_level = (logging.WARNING, logging.INFO, logging.DEBUG)[verbosity]
    logging.basicConfig(level=log_level, handlers=[logging.StreamHandler()])
    # basicConfig does nothing if the root logger already has handlers, so also set
    # the root level explicitly.
    logging.getLogger().setLevel(log_level)

    access_point = AccessPoint(
        emodel_dir=".",
        recipes_path="config/recipes.json",
        final_path="final.json",
        with_seeds=True,
    )
    with open("../../../mcmc_run/exemplar_data.yaml") as exemplar_file:
        exemplar_data = yaml.safe_load(exemplar_file)
    access_point.morph_path = exemplar_data["paths"]["all"]
    access_point.settings["morph_modifiers"] = [
        partial(synth_soma, params=exemplar_data["soma"], scale=1.0),
        partial(synth_axon, params=exemplar_data["ais"]["popt"], scale=1.0),
    ]

    # One evaluation row per scanned conductance value.
    iahps_basal = [0.0005, 0.005, 0.01, 0.015, 0.02, 0.025]
    df = pd.DataFrame(
        {
            "emodel": "simplest_spp",
            "name": [f"exemplar_{i}" for i in range(len(iahps_basal))],
            "new_parameters": [
                json.dumps({"gkbar_iahp.basal": iahp_basal}) for iahp_basal in iahps_basal
            ],
        }
    )
    print(df)

    Path("traces").mkdir(exist_ok=True)
    with Reuse("eval_scan_iahp.csv") as reuse:
        df = reuse(
            feature_evaluation,
            df,
            access_point,
            parallel_factory=parallel_factory,
            trace_data_path="traces",
            # record_ions_and_currents=True,
        )

    # Rows whose traces are plotted. NOTE: with six scan points, "last" is row 4
    # (gkbar_iahp.basal = 0.02), not the final row 5.
    trace_rows = {"first": 0, "middle": 2, "last": 4}
    for suffix, xmax in (("", 7500), ("_zoom", 5650)):
        for label, row in trace_rows.items():
            plot_traces(
                df.loc[[row]],
                pdf_filename=f"traces_scan_{label}{suffix}.pdf",
                xmax=xmax,
            )
    plt.close("all")

    features = get_feature_df(df)
    feature_plots = (
        (
            "Step_ReboundBurst_burst.soma.v.spikes_per_burst",
            "spike per burst",
            (3, 7.2),
            "scan_iahp.pdf",
        ),
        (
            "Step_ReboundBurst_burst.soma.v.postburst_min_values",
            "postburst min values",
            (-90, -70),
            "scan_iahp_postburst.pdf",
        ),
    )
    for feature, ylabel, (ymin, ymax), filename in feature_plots:
        fig = plt.figure(figsize=(4, 3))
        plt.plot(iahps_basal, features[feature], "+-", c="k")
        plt.xlabel("gkbar_iahp.basal")
        plt.ylabel(ylabel)
        plt.axis([0, 0.026, ymin, ymax])
        plt.tight_layout()
        plt.savefig(filename)
        plt.close(fig)
